#!/usr/bin/env python
# coding: utf-8


import pandas as pd
import numpy as np
import os
from python_speech_features import mfcc
from collections import Counter
from gmm_code.utils import read_wav

def get_feature(fs, signal):
    """Extract MFCC features from an audio signal.

    fs: sample rate in Hz (cast to int before use).
    signal: 1-D array of raw audio samples.
    Returns the MFCC feature matrix (one row of 13 coefficients per frame).
    """
    features = mfcc(signal, int(fs))
    if not len(features):
        # extraction produced no frames — report the signal length for debugging
        print("ERROR.. failed to extract mfcc feature:", len(signal))
    return features


def deal_mfcc(path, min_shape):
    """Load every CSV under *path*, extract MFCC features, and stack them
    into one array of shape (n_files_kept, min_shape, 13).

    Files whose MFCC has fewer than *min_shape* frames are silently skipped;
    longer ones are truncated to the first min_shape frames.

    path: directory containing CSV audio dumps (one signal per file).
    min_shape: target number of MFCC frames per sample.
    Returns a float array; empty (0, min_shape, 13) if nothing qualified.
    """
    samples = []
    for fname in os.listdir(path):
        csv_path = os.path.join(path, fname)
        # drop the first column; first row of the remaining column is metadata
        df = pd.read_csv(csv_path, header=None).iloc[:, 1:]
        # NOTE(review): assumes raw samples fit uint16 — confirm file format
        sig = np.array(df.values.T[0][1:], dtype=np.uint16)
        feat = np.nan_to_num(get_feature(16000, sig))
        if feat.shape[0] >= min_shape:
            # truncate trailing frames in a single slice
            # (was an O(n^2) loop of np.delete calls)
            samples.append(feat[:min_shape].reshape(1, min_shape, 13))
    if not samples:
        # nothing matched: keep the original's float dtype and 3-D shape
        return np.zeros((0, min_shape, 13), dtype=float)
    # single concatenate instead of growing the array with np.append per file
    return np.concatenate(samples, axis=0).astype(float)


def create_label(label_len, label_name):
    """Return a 1-D float array of length *label_len* filled with *label_name*.

    Used to build a constant class-label vector for one speaker's samples.
    Replaces the zeros-then-masked-assign idiom (which also broke for
    label_name == 0 only cosmetically) with a direct np.full.
    """
    return np.full((label_len,), float(label_name), dtype=float)


import keras
import matplotlib.pyplot as plt
from keras.callbacks import LambdaCallback
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM, Flatten
from keras.optimizers import RMSprop, Adam, Nadam,Adamax
from keras.utils.data_utils import get_file
import numpy as np
import random
import sys
import io
from keras.layers import Input, Dense, LSTM, RepeatVector, Reshape, Permute,Bidirectional
from keras.models import Model
from keras.layers import Bidirectional, concatenate, Conv1D, MaxPooling1D, GlobalMaxPooling1D, Dropout, BatchNormalization
import gc
from keras.utils import plot_model, np_utils
import time
import datetime
from keras.utils import to_categorical
from keras.callbacks import EarlyStopping
import warnings
from sklearn.model_selection import train_test_split
from numpy import random, mat
import multiprocessing
import os
import time
import math
from keras.utils import plot_model
warnings.filterwarnings('ignore')


# Number of output classes for the classifier.
num_sleep_states = 2  # binary classification (target speaker vs. unknown)
#num_sleep_states = 3  # 3-way classification variant
# NOTE(review): samplecount is only read (as an unused local) in Save.on_epoch_end
samplecount = 1500


def model_manydim():
    """Build and compile the 1D-CNN speaker-classification model.

    Input: MFCC feature matrices of shape (495, 13).
    Output: softmax over `num_sleep_states` classes.
    Returns the compiled Keras Model.
    """
    print('Build model...')

    # single input branch; X2/X3 are leftovers from a multi-input experiment
    X1 = Input(shape=(495,13))
    #X2 = Input(shape=(3733,))
    #X3 = Input(shape=(622,))

    # NOTE(review): this whole shared_Conv1D stack is constructed but never
    # connected to the graph below — only shared_Conv1D_tmp is used.
    shared1_Conv1D = Conv1D(64, 64, activation='relu')
    shared2_Conv1D = Conv1D(64, 32, activation='relu')
    shared3_Conv1D = Conv1D(64, 16, activation='relu')
    shared4_Conv1D = Conv1D(64, 8, activation='relu')
    shared5_Conv1D = Conv1D(64, 4, activation='relu')

    shared_Conv1D = Sequential()

    shared_Conv1D.add(shared1_Conv1D)
    shared_Conv1D.add(MaxPooling1D(8))
    # shared_Conv1D.add(Dropout(0.2))

    shared_Conv1D.add(shared2_Conv1D)
    shared_Conv1D.add(MaxPooling1D(4))
    # shared_Conv1D.add(Dropout(0.2))

    shared_Conv1D.add(shared3_Conv1D)
    shared_Conv1D.add(MaxPooling1D(4))
    # shared_Conv1D.add(Dropout(0.2))

    shared_Conv1D.add(shared4_Conv1D)
    shared_Conv1D.add(MaxPooling1D(2))
    # shared_Conv1D.add(Dropout(0.2))

    shared_Conv1D.add(shared5_Conv1D)

    shared_Conv1D.add(GlobalMaxPooling1D())

    # the branch actually used: two Conv1D/MaxPool/Dropout stages + global pool
    shared1_Conv1D_tmp = Conv1D(64, 16, activation='relu')
    shared2_Conv1D_tmp = Conv1D(64, 8, activation='relu')
    #shared3_Conv1D_tmp = Conv1D(64, 8, activation='relu')
    #shared4_Conv1D_tmp = Conv1D(64, 4, activation='relu')
    #shared5_Conv1D_tmp = Conv1D(64, 2, activation='relu')

    shared_Conv1D_tmp = Sequential()
    shared_Conv1D_tmp.add(shared1_Conv1D_tmp)
    shared_Conv1D_tmp.add(MaxPooling1D(4))
    shared_Conv1D_tmp.add(Dropout(0.2))
    shared_Conv1D_tmp.add(shared2_Conv1D_tmp)
    shared_Conv1D_tmp.add(MaxPooling1D(2))
    shared_Conv1D_tmp.add(Dropout(0.2))
    #shared_Conv1D_tmp.add(shared3_Conv1D_tmp)
    #shared_Conv1D_tmp.add(MaxPooling1D(2))
    # shared_Conv1D_tmp.add(Dropout(0.2))
    #shared_Conv1D_tmp.add(shared4_Conv1D_tmp)
    #shared_Conv1D_tmp.add(MaxPooling1D(2))
    # shared_Conv1D_tmp.add(Dropout(0.2))
    #shared_Conv1D_tmp.add(shared5_Conv1D_tmp)
    shared_Conv1D_tmp.add(GlobalMaxPooling1D())

    #tmpx1 = RepeatVector(1)(X1)
    #tmpx2 = RepeatVector(1)(X2)
    #tmpx3 = RepeatVector(1)(X3)
    #tmpx1 = Permute((2, 1))(tmpx1)
    #tmpx2 = Permute((2, 1))(tmpx2)
    #tmpx3 = Permute((2, 1))(tmpx3)

    tmpx1 = shared_Conv1D_tmp(X1)
    #tmpx2 = shared_Conv1D(tmpx2)
    #tmpx3 = shared_Conv1D_tmp(tmpx3)

    # RepeatVector(1) + Flatten is effectively an identity reshape here,
    # kept from the multi-branch concatenate experiment below
    tmpx1 = RepeatVector(1)(tmpx1)
    #tmpx2 = RepeatVector(1)(tmpx2)
    #tmpx3 = RepeatVector(1)(tmpx3)

    #tmpx = concatenate([tmpx1, tmpx2], axis=1)
    #tmpx = LSTM(256, activation='relu',return_sequences=True)(tmpx)
    tmpx = Flatten()(tmpx1)
    #tmpx = Dense(300, activation='relu')(tmpx)  # vip
    tmpx = Dense(64, activation='relu')(tmpx)
    #     tmpx = Dense(20,activation='relu')(tmpx)
    tmpx = Dropout(0.5)(tmpx)
    Y = Dense(num_sleep_states, activation='softmax')(tmpx)
    #Y = Dense(num_sleep_states, activation='sigmoid')(tmpx)
    #model = Model(inputs=[X1, X2], outputs=Y)
    model = Model(inputs=X1, outputs=Y)


    optimizer1 = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    model.compile(loss='categorical_crossentropy', optimizer=optimizer1, metrics=['accuracy'])

    print('print model...')
    model.summary()
    # plot_model(model,to_file='model_Conv1D_final.png')
    return model


def del_file(path):
    """Recursively delete every regular file under *path*.

    Directories are descended into but left in place themselves —
    only their file contents are removed.
    """
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if os.path.isdir(full_path):
            del_file(full_path)
        else:
            os.remove(full_path)
def shuffleindex(trainTexs):
    """Return *trainTexs* with its rows rearranged in a random order.

    The original array is not modified; a fancy-indexed copy is returned.
    (Renamed the local that previously shadowed the function's own name.)
    """
    order = list(range(trainTexs.shape[0]))
    random.shuffle(order)
    return trainTexs[order]
def getpathfiles(mypath, mylist):
    """Collect every .csv file under *mypath* (recursively) and return the
    paths as a randomly shuffled numpy array.

    mypath: root directory to walk.
    mylist: list the matches are appended to (mutated in place, as before).
    """
    for root, _dirs, files in os.walk(mypath):
        for fname in files:
            # splitext keeps the leading dot, hence ".csv"
            # (no longer shadows the builtin `type`)
            if os.path.splitext(fname)[1] == ".csv":
                mylist.append(os.path.join(root, fname))
    paths = np.array(mylist)
    order = list(range(paths.shape[0]))
    random.shuffle(order)
    return paths[order]

class Save(keras.callbacks.Callback):
    """Keras callback that keeps only the best (lowest val_loss) model on disk.

    Reads the module-level globals `mysavemodelpath` (checkpoint directory)
    and `model` (the compiled model), matching the original script's design.
    """

    def __init__(self):
        self.min_loss = 1.0
        existing = getpathfiles(mysavemodelpath, list())
        if len(existing) > 0:
            # Resume the best-loss threshold from a previous run's checkpoint
            # file name, e.g. ".../model_05_loss=0.123.h5".
            mname = existing[0]
            i1 = mname.find("_loss=")
            i2 = mname.find(".h5")
            if i1 != -1 and i2 != -1:
                # was mname[i1 + 1:i2], which kept the "loss=" text inside the
                # slice and made float() raise; skip the whole marker instead,
                # and seed min_loss (the old code stored it in an unused attr)
                self.min_loss = float(mname[i1 + len("_loss="):i2])

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        self.val_loss = logs["val_loss"]
        # newer Keras logs "val_accuracy"; older versions used "val_acc"
        self.acc = logs.get("val_acc", logs.get("val_accuracy"))
        if epoch != 0 and self.val_loss < self.min_loss:
            # zero-pad only single-digit epochs (the old `epoch > 0` test
            # padded every epoch, producing names like "010" for epoch 10)
            ename = str(epoch) if epoch >= 10 else "0" + str(epoch)
            del_file(mysavemodelpath)
            # Embed val_loss with the same "_loss=" marker __init__ parses so
            # the threshold survives a restart; the old code wrote "_acc=..."
            # which __init__ could never re-read.
            model.save(mysavemodelpath + "/model_" + ename + "_loss=" + str(self.val_loss) + ".h5")
            self.min_loss = self.val_loss
            print("kears_model_ Save")


def deal_data(train_date_c1s, train_date_rs):
    """Standardize the feature array and one-hot encode the labels.

    train_date_c1s: feature array (any shape); standardized with its own
        global mean and population std.
    train_date_rs: integer class labels.
    Returns (standardized features, one-hot label matrix).
    """
    train_date_c1s = train_date_c1s.astype(float)
    train_date_rs = train_date_rs.astype(float)
    # Follow the module-level class-count switch instead of a hard-coded 2,
    # so the commented-out 3-class setup works without editing this function.
    trainrs_stat = to_categorical(train_date_rs, num_classes=num_sleep_states)

    vgac1 = np.mean(train_date_c1s)
    varc1 = math.sqrt(np.var(train_date_c1s))  # population std (ddof=0)
    if varc1 == 0:
        # constant input: avoid division by zero, just center the data
        varc1 = 1.0
    train_date_c1s = (train_date_c1s - vgac1) / varc1

    return train_date_c1s, trainrs_stat


if __name__ == '__main__':
    # --- data preparation: MFCC features per speaker directory ---
    path_linkunling = '/home/banana/bear_voice/src/mfcc_tcnn_speaker/script/human_txt/linkunling1/'
    linkunling_data = deal_mfcc(path_linkunling ,495)
    linkunling_label = create_label(linkunling_data.shape[0] , 1)
    #path_jiefei = '/home/banana/bear_voice/src/mfcc_tcnn_speaker/script/human_txt/jiefei/'
    #jiefei_data = deal_mfcc(path_jiefei,499)
    #jiefei_label = create_label(jiefei_data.shape[0] , 1)
    path_unknown = '/home/banana/bear_voice/src/mfcc_tcnn_speaker/script/human_txt/unknown_all/'
    unknown_data = deal_mfcc(path_unknown,495)
    unknown_label = create_label(unknown_data.shape[0] , 0)
    #data = np.concatenate((jiefei_data, unknown_data, linkunling_data),axis=0)  # 3-class variant
    #label = np.concatenate((jiefei_label, unknown_label, linkunling_label),axis=0)  # 3-class variant
    data = np.concatenate((unknown_data, linkunling_data),axis=0)
    label = np.concatenate((unknown_label, linkunling_label),axis=0)
    # split, then standardize each split
    # NOTE(review): train and test are each scaled by their OWN mean/std here,
    # not by the train statistics — confirm this is intended
    X_train, X_test, y_train, y_test = train_test_split(data, label, test_size=0.2, random_state=12)
    train_1,train_Y = deal_data(X_train,y_train)
    test_1,test_Y = deal_data(X_test,y_test) 
    # checkpoint directory — read as a module-level global by the Save callback
    mysavemodelpath = '/home/banana/bear_voice/src/mfcc_tcnn_speaker/script/model'  
    model = model_manydim()
    save_function = Save()
    history = model.fit(train_1, train_Y,
			  batch_size=128,  # 100 is vip
			  epochs=500,
			  shuffle=True,
			  callbacks=[save_function],
			  validation_data=(test_1, test_Y))
      


#plotting (disabled): training & validation accuracy curves
#plt.plot(history.history['acc'])
#plt.plot(history.history['val_acc'])
#plt.title('Model accuracy')
#plt.ylabel('Accuracy')
#plt.xlabel('Epoch')
#plt.legend(['Train', 'Test'], loc='upper left')
#plt.savefig('./acc.jpg')
#plt.show()

# Plot training & validation loss (disabled)
#fig = plt.figure()
#plt.plot(history.history['loss'])
#plt.plot(history.history['val_loss'])
#plt.title('Model loss')
#plt.ylabel('Loss')
#plt.xlabel('Epoch')
#plt.legend(['Train', 'Test'], loc='upper left')
#plt.savefig('./loss.jpg')
#plt.show()





